{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 简化推理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. {class}`~tvm.relay.transform.SimplifyInference` (见 `tvm/src/relay/transforms/simplify_inference.cc`)\n",
    "\n",
    "```c++\n",
    "InferenceSimplifier()\n",
    "      : batch_norm_op_(Op::Get(\"nn.batch_norm\")),\n",
    "        dropout_op_(Op::Get(\"nn.dropout\")),\n",
    "        instance_norm_op_(Op::Get(\"nn.instance_norm\")),\n",
    "        layer_norm_op_(Op::Get(\"nn.layer_norm\")),\n",
    "        group_norm_op_(Op::Get(\"nn.group_norm\")),\n",
    "        l2_norm_op_(Op::Get(\"nn.l2_normalize\")) {}\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.ir import IRModule, structural_equal\n",
    "from tvm import relay as rly\n",
    "from tvm.relay.transform import SimplifyInference, InferType"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义简单的 batch-norm(可以参考：[batch-norm](https://zh.d2l.ai/chapter_convolutional-modern/batch-norm.html))：\n",
    "\n",
    "$$\n",
    "\\mathrm{BN}(\\mathbf{x}) = \\boldsymbol{\\gamma} \\odot \\frac{\\mathbf{x} - \\hat{\\boldsymbol{\\mu}}_\\mathcal{B}}{\\hat{\\boldsymbol{\\sigma}}_\\mathcal{B}} + \\boldsymbol{\\beta}.\n",
    "$$\n",
    "\n",
    "其中 $\\hat{\\boldsymbol{\\mu}}_\\mathcal{B}$ 和 $\\hat{\\boldsymbol{\\sigma}}_\\mathcal{B}$ 分别是小批量 $\\mathcal{B}$ 的样本均值和样本标准差。\n",
    "\n",
    "$$\n",
    "\\begin{split}\\begin{aligned} \\hat{\\boldsymbol{\\mu}}_\\mathcal{B} &= \\frac{1}{|\\mathcal{B}|} \\sum_{\\mathbf{x} \\in \\mathcal{B}} \\mathbf{x},\\\\\n",
    "\\hat{\\boldsymbol{\\sigma}}_\\mathcal{B}^2 &= \\frac{1}{|\\mathcal{B}|} \\sum_{\\mathbf{x} \\in \\mathcal{B}} (\\mathbf{x} - \\hat{\\boldsymbol{\\mu}}_{\\mathcal{B}})^2 + \\epsilon.\\end{aligned}\\end{split}\n",
    "$$\n",
    "\n",
    "应用标准化后，生成的小批量的平均值为 $0$ 和单位方差为 $1$。由于单位方差是主观的选择，因此通常需要包含拉伸参数（scale） $\\boldsymbol{\\gamma}$ 和偏移参数（shift） $\\boldsymbol{\\beta}$，它们的形状与 $\\mathbf{x}$ 相同。请注意，$\\boldsymbol{\\gamma}$ 和 $\\boldsymbol{\\beta}$ 是需要与其他模型参数一起学习的参数。"
   ]
  },
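  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the formulas above, the sketch below applies them to a minibatch of scalars in plain Python. The helpers `batch_norm_train` and `batch_norm_infer` are hypothetical names introduced here, not TVM APIs. The second helper assumes the statistics are fixed at inference time (for example, moving averages collected during training), in which case the normalization collapses into one multiply and one add.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def batch_norm_train(xs, gamma, beta, eps):\n",
    "    # Normalize a minibatch of scalars with its own statistics, then scale and shift.\n",
    "    mean = sum(xs) / len(xs)\n",
    "    # As in the formula above, epsilon is folded into the variance estimate.\n",
    "    var = sum((x - mean) ** 2 for x in xs) / len(xs) + eps\n",
    "    std = math.sqrt(var)\n",
    "    return [gamma * (x - mean) / std + beta for x in xs]\n",
    "\n",
    "\n",
    "def batch_norm_infer(xs, gamma, beta, moving_mean, moving_var, eps):\n",
    "    # Assumed inference form: fixed statistics fold into a scale and a shift.\n",
    "    scale = gamma / math.sqrt(moving_var + eps)\n",
    "    shift = beta - moving_mean * scale\n",
    "    return [x * scale + shift for x in xs]\n",
    "\n",
    "\n",
    "batch = [1.0, 2.0, 3.0, 4.0]\n",
    "out = batch_norm_train(batch, gamma=2.0, beta=0.5, eps=0.01)\n",
    "```\n",
    "\n",
    "Because the normalized values have zero mean, the mean of `out` equals `beta`. Passing the batch's own mean and variance to `batch_norm_infer` reproduces `out` up to floating-point rounding."
   ]
  },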
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 4\n",
    "nstep = 1 # 1, 3\n",
    "axis = 1 # 0, 1\n",
    "dtype = \"float16\"\n",
    "eps = 0.01\n",
    "ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)\n",
    "ttype2 = rly.TensorType((10,), dtype)\n",
    "x = rly.var(\"x\", ttype1)\n",
    "beta = rly.var(\"beta\", ttype2)\n",
    "gamma = rly.var(\"gamma\", ttype2)\n",
    "moving_var = rly.var(\"moving_var\", ttype2)\n",
    "moving_mean = rly.var(\"moving_mean\", ttype2)\n",
    "y1, y2 = x, x\n",
    "for _ in range(nstep):\n",
    "    y1, _, _ = rly.nn.batch_norm(\n",
    "        y1 + rly.const(1, dtype),\n",
    "        gamma,\n",
    "        beta,\n",
    "        moving_mean,\n",
    "        moving_var,\n",
    "        epsilon=eps,\n",
    "        axis=axis,\n",
    "    )\n",
    "    y1 = rly.nn.dropout(y1)\n",
    "mod = IRModule.from_expr(y1)\n",
    "simplify = SimplifyInference()\n",
    "mod = InferType()(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
